# Multi-label classification on COCO images using CocoDetection metadata.
# Requires pycocotools installed.
import argparse
import os
import torch
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import CocoDetection

from MK_CAViT import mk_cavit_base
from train_utils import set_seed, train_one_epoch, evaluate_cls, coco_multilabel_collate


def build_coco_loaders(root: str, ann_train: str, ann_val: str,
                       img_size: int, batch: int, workers: int):
    """Build the COCO training and validation data loaders.

    Both splits use the same deterministic preprocessing: resize to
    ``int(img_size * 1.14)``, center-crop to ``img_size``, convert to a
    tensor, and normalize per channel with the mean/std given below.

    Args:
        root: Directory containing the ``train2017`` and ``val2017`` image
            folders.
        ann_train: Path to the annotation file for the training split.
        ann_val: Path to the annotation file for the validation split.
        img_size: Side length of the final square crop.
        batch: Batch size used by both loaders.
        workers: Number of DataLoader worker processes.

    Returns:
        A ``(train_loader, val_loader)`` tuple. Only the training loader
        shuffles its samples.
    """
    preprocess = T.Compose([
        T.Resize(int(img_size * 1.14)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_set = CocoDetection(os.path.join(root, 'train2017'), ann_train,
                              transform=preprocess)
    val_set = CocoDetection(os.path.join(root, 'val2017'), ann_val,
                            transform=preprocess)

    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only machine is useless and makes DataLoader emit a warning.
    shared_kwargs = dict(
        batch_size=batch,
        num_workers=workers,
        pin_memory=torch.cuda.is_available(),
        # Project helper; presumably turns the per-image annotation lists
        # into multi-label targets -- confirm in train_utils.
        collate_fn=coco_multilabel_collate,
    )
    tr_loader = DataLoader(train_set, shuffle=True, **shared_kwargs)
    va_loader = DataLoader(val_set, shuffle=False, **shared_kwargs)
    return tr_loader, va_loader


def main(args):
    """Train the model on COCO and save its final weights.

    Args:
        args: Parsed command-line namespace. Reads ``seed``, ``size``,
            ``root``, ``ann_train``, ``ann_val``, ``batch_size``,
            ``workers``, ``lr``, ``amp``, ``epochs``, ``mu`` and ``out``.

    Side effects:
        Prints one line with the train and validation loss per epoch and
        writes the model ``state_dict`` to ``args.out`` after the last epoch.
    """
    set_seed(args.seed)

    # Create the checkpoint directory before training: otherwise a missing
    # directory is only discovered by torch.save() after the final epoch,
    # and the whole run is lost.
    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 80 output logits -- presumably one per COCO object category.
    model = mk_cavit_base(num_classes=80, img_size=args.size).to(device)

    tr_loader, va_loader = build_coco_loaders(
        args.root, args.ann_train, args.ann_val, args.size, args.batch_size, args.workers
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.05)
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    for epoch in range(args.epochs):
        # ``mu`` is forwarded unchanged to the project training helper.
        tr = train_one_epoch(model, tr_loader, optimizer, device, task='multilabel', mu=args.mu, scaler=scaler)
        ev = evaluate_cls(model, va_loader, device, task='multilabel')
        print(f"[{epoch+1:03d}/{args.epochs:03d}] train loss {tr['loss']:.4f} | val loss {ev['loss']:.4f}")

    torch.save(model.state_dict(), args.out)


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, and hand
    # the resulting namespace to main(). Options are registered in the same
    # order as before so the generated --help output is unchanged.
    parser = argparse.ArgumentParser()

    # Mandatory dataset locations.
    required_paths = (
        ('--root', 'COCO root containing train2017/ and val2017/'),
        ('--ann_train', 'instances_train2017.json path'),
        ('--ann_val', 'instances_val2017.json path'),
    )
    for flag, description in required_paths:
        parser.add_argument(flag, type=str, required=True, help=description)

    # Integer-valued settings for image size, schedule and data loading.
    int_options = (
        ('--size', 224),
        ('--epochs', 30),
        ('--batch_size', 64),
        ('--workers', 8),
    )
    for flag, default_value in int_options:
        parser.add_argument(flag, type=int, default=default_value)

    # Float-valued optimisation settings.
    float_options = (
        ('--lr', 3e-4),
        ('--mu', 0.1),
    )
    for flag, default_value in float_options:
        parser.add_argument(flag, type=float, default=default_value)

    # Remaining switches: mixed precision, RNG seed, output checkpoint path.
    parser.add_argument('--amp', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--out', type=str, default='mk_cavit_coco.pth')

    cli_args = parser.parse_args()
    main(cli_args)
